{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "FPjTnCG-v2Z6",
    "outputId": "f83c55eb-eba5-44ed-830d-498e7d3abc73"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "zsh:1: command not found: nvidia-smi\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Notebook for training the embedding model for the flow around a cylinder system.\n",
    "=====\n",
    "Distributed by: Notre Dame SCAI Lab (MIT License)\n",
    "- Associated publication:\n",
    "url: https://arxiv.org/abs/2010.03957\n",
    "doi: \n",
    "github: https://github.com/zabaras/transformer-physx\n",
    "=====\n",
    "\"\"\"\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "34XMtg9FZFql"
   },
   "source": [
    "# Environment Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "9zaXL-m8xEx9",
    "outputId": "c8a3a28e-2d70-4cd8-94b2-b3c87c305354"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m\u001b[36mdata\u001b[m\u001b[m                             train_cylinder_enn.py\n",
      "\u001b[1m\u001b[36moutputs\u001b[m\u001b[m                          train_cylinder_transformer.ipynb\n",
      "train_cylinder_enn.ipynb         train_cylinder_transformer.py\n"
     ]
    }
   ],
   "source": [
    "!ls"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6MK_h8wF0Rr4"
   },
   "source": [
    "Now let's download the training and validation data for the cylinder system. The data will eventually be moved to a Zenodo repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "2NtZ02zD0EKo"
   },
   "outputs": [],
   "source": [
    "!mkdir -p data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "G4TdTWCNNLOY"
   },
   "source": [
    "**WARNING: The training data file is 1.3 GB and the validation file is 0.35 GB!** These will be stored in your Google Drive!\n",
    "\n",
    "Because these files are large, we use gdown instead of wget to avoid the virus-scan warning that Google Drive shows for large files."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dCHiYnrdZN95"
   },
   "source": [
    "# Transformer-PhysX Cylinder System"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YDVeeJHn11Ir"
   },
   "source": [
    "Train the embedding model.\n",
    "First, import the necessary modules from trphysx."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "dNGVZQ-o1gsH"
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import logging\n",
    "\n",
    "import paddle\n",
    "from paddle.optimizer.lr import ExponentialDecay\n",
    "\n",
    "from trphysx.config.configuration_auto import AutoPhysConfig\n",
    "from trphysx.embedding.embedding_auto import AutoEmbeddingModel\n",
    "from trphysx.viz.viz_auto import AutoViz\n",
    "from trphysx.embedding.training import *\n",
    "\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YEft0ltg4swx"
   },
   "source": [
    "Training arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "8R8QQ0cj4qR9"
   },
   "outputs": [],
   "source": [
    "argv = []\n",
    "argv = argv + [\"--exp_name\", \"cylinder\"]\n",
    "argv = argv + [\"--training_h5_file\", \"./data/cylinder_train.hdf5\"]\n",
    "argv = argv + [\"--eval_h5_file\", \"./data/cylinder_valid.hdf5\"]\n",
    "argv = argv + [\"--batch_size\", \"32\"]\n",
    "argv = argv + [\"--block_size\", \"4\"]\n",
    "argv = argv + [\"--n_train\", \"27\"]\n",
    "argv = argv + [\"--n_eval\", \"6\"]\n",
    "argv = argv + [\"--epochs\", \"100\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "wcBfHEh540f6",
    "outputId": "fb5c7820-ff16-473f-e1e7-436cb4e41521"
   },
   "outputs": [],
   "source": [
    "args = EmbeddingParser().parse(args=argv)\n",
    "\n",
    "# Set up logging\n",
    "logging.basicConfig(\n",
    "    format=\"%(asctime)s - %(levelname)s - %(name)s -   %(message)s\",\n",
    "    datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
    "    level=logging.INFO,\n",
    ")\n",
    "\n",
    "# Use the first GPU if Paddle was compiled with CUDA; otherwise use the CPU\n",
    "if paddle.device.is_compiled_with_cuda():\n",
    "    device_name = \"gpu:0\"\n",
    "else:\n",
    "    device_name = \"cpu\"\n",
    "args.device = paddle.set_device(device_name)\n",
    "logger.info(\"Paddle device: {}\".format(args.device))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "T9Fel9vqZVsH"
   },
   "source": [
    "## Initializing Datasets and Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JjxiqD3lF98_"
   },
   "source": [
    "Now we can use the auto classes to initialize the predefined configs, data loaders, and models. This may take a bit!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "EcpC9Fy243RN",
    "outputId": "9cab7442-f19d-4408-c119-2d079c432ba8"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2023-01-19 21:29:08,822] [ WARNING] enn_data_handler.py:425 - Lower batch-size to 6\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "embedding_model.observableNet.0.weight paddle.float32 False\n",
      "embedding_model.observableNet.0.bias paddle.float32 False\n",
      "embedding_model.observableNet.2.weight paddle.float32 False\n",
      "embedding_model.observableNet.2.bias paddle.float32 False\n",
      "embedding_model.observableNet.4.weight paddle.float32 False\n",
      "embedding_model.observableNet.4.bias paddle.float32 False\n",
      "embedding_model.observableNet.6.weight paddle.float32 False\n",
      "embedding_model.observableNet.6.bias paddle.float32 False\n",
      "embedding_model.observableNet.8.weight paddle.float32 False\n",
      "embedding_model.observableNet.8.bias paddle.float32 False\n",
      "embedding_model.observableNetFC.0.weight paddle.float32 False\n",
      "embedding_model.observableNetFC.0.bias paddle.float32 False\n",
      "embedding_model.recoveryNet.1.weight paddle.float32 False\n",
      "embedding_model.recoveryNet.1.bias paddle.float32 False\n",
      "embedding_model.recoveryNet.4.weight paddle.float32 False\n",
      "embedding_model.recoveryNet.4.bias paddle.float32 False\n",
      "embedding_model.recoveryNet.7.weight paddle.float32 False\n",
      "embedding_model.recoveryNet.7.bias paddle.float32 False\n",
      "embedding_model.recoveryNet.10.weight paddle.float32 False\n",
      "embedding_model.recoveryNet.10.bias paddle.float32 False\n",
      "embedding_model.recoveryNet.12.weight paddle.float32 False\n",
      "embedding_model.recoveryNet.12.bias paddle.float32 False\n",
      "embedding_model.kMatrixDiagNet.0.weight paddle.float32 False\n",
      "embedding_model.kMatrixDiagNet.0.bias paddle.float32 False\n",
      "embedding_model.kMatrixDiagNet.2.weight paddle.float32 False\n",
      "embedding_model.kMatrixDiagNet.2.bias paddle.float32 False\n",
      "embedding_model.kMatrixUT.0.weight paddle.float32 False\n",
      "embedding_model.kMatrixUT.0.bias paddle.float32 False\n",
      "embedding_model.kMatrixUT.2.weight paddle.float32 False\n",
      "embedding_model.kMatrixUT.2.bias paddle.float32 False\n",
      "embedding_model.kMatrixLT.0.weight paddle.float32 False\n",
      "embedding_model.kMatrixLT.0.bias paddle.float32 False\n",
      "embedding_model.kMatrixLT.2.weight paddle.float32 False\n",
      "embedding_model.kMatrixLT.2.bias paddle.float32 False\n"
     ]
    }
   ],
   "source": [
    "# Load transformer config file\n",
    "config = AutoPhysConfig.load_config(args.exp_name)\n",
    "data_handler = AutoDataHandler.load_data_handler(args.exp_name)\n",
    "viz = AutoViz.load_viz(args.exp_name, plot_dir=args.plot_dir)\n",
    "\n",
    "# Set up data-loaders\n",
    "training_loader = data_handler.createTrainingLoader(\n",
    "    args.training_h5_file,\n",
    "    block_size=args.block_size,\n",
    "    stride=args.stride,\n",
    "    ndata=args.n_train,\n",
    "    batch_size=args.batch_size,\n",
    ")\n",
    "testing_loader = data_handler.createTestingLoader(\n",
    "    args.eval_h5_file, block_size=32, ndata=args.n_eval, batch_size=8\n",
    ")\n",
    "\n",
    "# Set up model\n",
    "model = AutoEmbeddingModel.init_trainer(args.exp_name, config).to(args.device)\n",
    "mu, std = data_handler.norm_params\n",
    "# model.embedding_model.mu = mu.to(args.device)\n",
    "# model.embedding_model.std = std.to(args.device)\n",
    "if args.epoch_start > 1:\n",
    "    model.load_model(args.ckpt_dir, args.epoch_start)\n",
    "\n",
    "# Print each parameter's name, dtype, and stop_gradient flag, and tally the parameter count\n",
    "count = 0\n",
    "for name, param in model.named_parameters():\n",
    "    print(name, param.dtype, param.stop_gradient)\n",
    "    count += param.numel()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "auDiMVZ5UfNz"
   },
   "source": [
    "Initialize the optimizer and scheduler. Feel free to change them if you want to experiment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "bnrtuKdhGuWQ"
   },
   "outputs": [],
   "source": [
    "# Scale the initial learning rate by the decay already applied before epoch_start,\n",
    "# so that a resumed run continues the same schedule\n",
    "scheduler = ExponentialDecay(\n",
    "    learning_rate=args.lr * 0.995 ** (args.epoch_start - 1), gamma=0.995\n",
    ")\n",
    "optimizer = paddle.optimizer.Adam(\n",
    "    parameters=model.parameters(),\n",
    "    learning_rate=scheduler,\n",
    "    weight_decay=1e-8,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "It_LLGQIZe0b"
   },
   "source": [
    "## Training the Embedding Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "U2XPKfYTUuXf"
   },
   "source": [
    "For brevity, this notebook trains for only 100 epochs; feel free to train longer. The test loss reported here is only the recovery loss, MSE(x - decode(encode(x))), and does not reflect the quality of the Koopman dynamics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "2ic9DFQcUWpm",
    "outputId": "8ca81463-f225-4ef5-b8a3-859992658663"
   },
   "outputs": [],
   "source": [
    "trainer = EmbeddingTrainer(model, args, (optimizer, scheduler), viz)\n",
    "trainer.train(training_loader, testing_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nUouTmv3Spye"
   },
   "source": [
    "## Visualization of Results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "N-EsF8GUSqmQ"
   },
   "source": [
    "Display some test predictions of the embedding model. Each plot shows the model's prediction for a single Koopman step, i.e. x(t+1) = decoder(K*encoder(x(t))). Random time-steps are plotted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "w7J5GgEFh_8t",
    "outputId": "c051b0ea-b8d3-405d-ba87-ad0c62184b0a"
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "\n",
    "# Show the saved validation prediction plot for each selected epoch\n",
    "for epoch in [1, 25, 50, 75, 100]:\n",
    "    print('Validation prediction for epoch: {:d}'.format(epoch))\n",
    "    file_path = './outputs/embedding_cylinder/ntrain27_epochs100_batch32/predictions/embeddingPred0_{:d}.png'.format(epoch)\n",
    "    display(Image(file_path, width=600, height=300))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "train_cylinder_enn.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "db64661fe0d122184d8f0ece1104b953994d7acbc475dabbdd4eb4bc24907e06"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
